{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47020873",
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = [\n",
    "    #数字\n",
    "    ['5', '2', '4', '8', '6', '2', '3', '6', '4'],\n",
    "    ['4', '8', '5', '6', '9', '5', '5', '6'],\n",
    "    ['1', '1', '5', '2', '3', '3', '8'],\n",
    "    ['3', '6', '9', '6', '8', '7', '4', '6', '3'],\n",
    "    ['8', '9', '9', '6', '1', '4', '3', '4'],\n",
    "    ['1', '0', '2', '0', '2', '1', '3', '3', '3', '3', '3'],\n",
    "    ['9', '3', '3', '0', '1', '4', '7', '8'],\n",
    "    ['9', '9', '8', '5', '6', '7', '1', '2', '3', '0', '1', '0'],\n",
    "\n",
    "    #字母中夹杂了一些数字\n",
    "    ['a', 't', 'g', 'q', 'e', 'h', '9', 'u', 'f'],\n",
    "    ['e', 'q', 'y', 'u', 'o', 'i', 'p', 's'],\n",
    "    ['q', 'o', '9', 'p', 'l', 'k', 'j', 'o', 'k', 'k', 'o', 'p'],\n",
    "    ['h', 'g', 'y', 'i', 'u', 't', 't', 'a', 'e', 'q'],\n",
    "    ['i', 'k', 'd', 'q', 'r', 'e', '9', 'e', 'a', 'd'],\n",
    "    ['o', 'p', 'd', 'g', '9', 's', 'a', 'f', 'g', 'a'],\n",
    "    ['i', 'u', 'y', 'g', 'h', 'k', 'l', 'a', 's', 'w'],\n",
    "    ['o', 'l', 'u', 'y', 'a', 'o', 'g', 'f', 's'],\n",
    "    ['o', 'p', 'i', 'u', 'y', 'g', 'd', 'a', 's', 'j', 'd', 'l'],\n",
    "    ['u', 'k', 'i', 'l', 'o', '9', 'l', 'j', 's'],\n",
    "    ['y', 'g', 'i', 's', 'h', 'k', 'j', 'l', 'f', 'r', 'f'],\n",
    "    ['i', 'o', 'h', '9', 'n', '9', 'd', '9', 'f', 'a', '9'],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fe303ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(9, 30)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#生成字典\n",
    "zidian = {}\n",
    "for doc in docs:\n",
    "    for word in doc:\n",
    "        if word not in zidian:\n",
    "            zidian[word] = len(zidian)\n",
    "\n",
    "zidian['0'], len(zidian)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcebc1c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "652"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "#定义数据\n",
    "class DocDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        xs = []\n",
    "        ys = []\n",
    "\n",
    "        for doc in docs:\n",
    "\n",
    "            #遍历每个句子的每个词,作为x\n",
    "            for i in range(len(doc)):\n",
    "\n",
    "                #遍历当前字的偏移,作为y\n",
    "                for j in [-2, -1, 1, 2]:\n",
    "\n",
    "                    #如果出界了就跳过\n",
    "                    if i + j < 0 or i + j >= len(doc):\n",
    "                        continue\n",
    "\n",
    "                    xs.append(zidian[doc[i]])\n",
    "                    ys.append(zidian[doc[i + j]])\n",
    "\n",
    "        self.xs = torch.LongTensor(xs)\n",
    "        self.ys = torch.LongTensor(ys)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.xs[i], self.ys[i]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.xs)\n",
    "\n",
    "\n",
    "len(DocDataset())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "25f1a6cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([19, 17,  6, 14]),\n",
       " torch.Size([4]),\n",
       " tensor([ 6,  6, 13,  6]),\n",
       " torch.Size([4]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据加载器\n",
    "def get_dataloader():\n",
    "    dataloader = DataLoader(dataset=DocDataset(),\n",
    "                            batch_size=4,\n",
    "                            shuffle=True,\n",
    "                            drop_last=True)\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "for i, data in enumerate(get_dataloader()):\n",
    "    sample = data\n",
    "    break\n",
    "\n",
    "sample[0], sample[0].shape, sample[1], sample[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e258165",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.5162,  0.3245,  0.3053,  0.2911, -0.3500,  0.5463, -0.5317, -0.5120,\n",
       "           0.3480,  0.4718, -0.5056,  0.5381,  0.3380,  0.3746, -0.0517, -0.5607,\n",
       "          -0.0462,  0.3976, -0.5185, -0.5205,  0.2119,  0.3325,  0.7214, -0.5070,\n",
       "          -0.5151,  0.3951, -0.1502, -0.5719, -0.7020,  0.1182],\n",
       "         [-0.6179,  0.3763,  0.3163,  0.3189, -0.3094,  0.6617, -0.4516, -0.6138,\n",
       "           0.2164,  0.5931, -0.4071,  0.5344,  0.3486,  0.3853, -0.1799, -0.5097,\n",
       "          -0.0174,  0.3685, -0.5632, -0.6166,  0.2338,  0.4011,  0.6183, -0.4409,\n",
       "          -0.6470,  0.4314, -0.0168, -0.4587, -0.6129,  0.1485]],\n",
       "        grad_fn=<SliceBackward>), torch.Size([4, 30]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class SkipGram(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embed = torch.nn.Embedding(30, 2)\n",
    "        self.embed.weight.data.normal_(0,0.1)\n",
    "        \n",
    "        self.fc = torch.nn.Linear(2, 30)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        #[b] -> [b,2]\n",
    "        x = self.embed(x)\n",
    "\n",
    "        #[b,2] -> [b,30]\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "model = SkipGram()\n",
    "\n",
    "out = model(sample[0])\n",
    "\n",
    "out[:2], out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4c12d416",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(3.3525, grad_fn=<NllLossBackward>)\n",
      "20 tensor(3.2004, grad_fn=<NllLossBackward>)\n",
      "40 tensor(2.5589, grad_fn=<NllLossBackward>)\n",
      "60 tensor(3.3490, grad_fn=<NllLossBackward>)\n",
      "80 tensor(2.5052, grad_fn=<NllLossBackward>)\n",
      "100 tensor(2.6275, grad_fn=<NllLossBackward>)\n",
      "120 tensor(2.9588, grad_fn=<NllLossBackward>)\n",
      "140 tensor(1.5526, grad_fn=<NllLossBackward>)\n",
      "160 tensor(2.6236, grad_fn=<NllLossBackward>)\n",
      "180 tensor(2.2506, grad_fn=<NllLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "criteon = torch.nn.CrossEntropyLoss()\n",
    "optim = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "\n",
    "model.train()\n",
    "for epoch in range(200):\n",
    "    for i, data in enumerate(get_dataloader()):\n",
    "        x, y = data\n",
    "        \n",
    "        optim.zero_grad()\n",
    "\n",
    "        #计算\n",
    "        #[b] -> [b,30]\n",
    "        out = model(x)\n",
    "\n",
    "        loss = criteon(out, y)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "\n",
    "    if epoch % 20 == 0:\n",
    "        print(epoch, loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3718d127",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f56cd0397f0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAe6UlEQVR4nO3de5hU1ZX38e9qaGiae6QRwkUcZYziFVtEmTEYNQFFGCdGiVHUqKgBnzgmk2DU5DGa+E7eqPE2EmIUMVGDisIkIGLE66ihMYICQtAodgBp8cKlGxro9f6xi7eb7tPQUNV1qur8Ps9TT1edc6i9LHDV7n32XtvcHRERKXxFcQcgIiLZoYQvIpIQSvgiIgmhhC8ikhBK+CIiCdE27gB2p0ePHj5gwIC4wxARyRsLFy782N3Los7ldMIfMGAAFRUVcYchIpI3zOyD5s6lPaRjZv3MbL6ZLTOzJWb23YhrzMzuNLOVZrbYzAan266IiOydTPTwtwPfc/c3zKwzsNDM5rn70gbXjAQGph7HA/emfoqISJak3cN39zXu/kbq+UZgGdCn0WVjgGkevAZ0M7Pe6bYtIiItl9FZOmY2ADgGeL3RqT7Ahw1eV9L0S2Hne4w3swozq6iqqspkeCIiiZaxhG9mnYAngKvdfUPj0xF/JLKIj7tPcfdydy8vK4u80Swiu7FsGbz0EmzaFHckkmsykvDNrJiQ7H/v7jMiLqkE+jV43RdYnYm2RSRYvRoGD4bychg1CvbfH+66K+6oJJdkYpaOAb8Flrn7bc1cNgsYl5qtMxT43N3XpNu2iNQbNQoWL4bqatiwIfycNAnmz487MskVmZilMwy4AHjLzN5MHfsR0B/A3ScDs4HTgZVANXBxBtoVkZR33oHly2HHjl2PV1fDr34FJ58cS1iSY9JO+O7+MtFj9A2vcWBCum2JSLT166FtM/83r12b3Vgkd6mWjkgBOPpo2L696fGSEjjzzKyHIzlKCV+kAHTsCL/4BZSW1h8rKQk3bidOjC8uyS05XUtHRFpuwgQYNCiM2a9dC2ecAVddBd26xR2Z5AolfJECMnx4eIhE0ZCOiEhCKOGLiCSEEr6ISEIo4YuIJIQSvohIQijhi4gkhBK+iEhCaB6+SJ5591149VXo1SsURWvTJu6IJF8o4Yvkibo6uOwyePjhUCjNDLp3h+efhwMPjDs6yQca0hHJEw89BI8+Clu2hN2sNm6Eyko466y4I5N8oYQvkifuuSfUt2+org5WrIC//z2emCS/KOGL5InNm6OPuzf9IhCJooQvkieaq2u/ZQt06ZLdWCQ/KeGL5Ik+fcKN2saKi+Gxx7Ifj+QfJXyRPFFbG72N4fbt8NlnWQ9H8pASvkie+NrXohN+aSmMHJn9eCT/KOGL5InDD4eLLgrbGe7UsSOMGQNDh8YWluQRLbwSySP33AOjR8PUqbBjB5x/friZGzW2L9KYEr5IHjGDESPCQ2RvaUhHRCQhMpLwzex+M1tnZm83c364mX1uZm+mHj/ORLsikl3vvhumgC5YEBZ8SX7J1JDOVOBuYNpurnnJ3UdlqD0RyaIdO2DcOJgxA9q1C68HDoR586BHj7ijk5bKSA/f3V8EPsnEe4lI7rnjDnjqqbCqd8OGUOZhyRK48MK4I5O9kc0x/BPMbJGZzTGzQc1dZGbjzazCzCqqqqqyGJ6INCeqcNu2bfDss+ELQPJDthL+G8AB7n4UcBfwVHMXuvsUdy939/KysrIshSciu7NpU/RxM6ipyW4ssu+ykvDdfYO7b0o9nw0Um5lG/kTyxJlnRq/y7d8fevbMfjyyb7KS8M2sl1lYGmJmQ1Ltrs9G2yKSvptuCjdnS0vD63btwirf++/Xoq98kpFZOmb2CDAc6GFmlcBPgGIAd58MnA1caWbbgRpgrLsmdYnki969Ydky+O1v4cUX4ZBD4DvfgQED4o5M9oblct4tLy/3ioqKuMMQEckbZrbQ3cujzmmlrYhIQijhi2TJ5s1w661w4olwxhkwZ07cEUnSqHiaSBZUV8OQIWGz8Z3TGF94AX7wA/ixCo1IlqiHL5IF06bB++/vOmd982a45RbQ+kLJFiV8kSz4n/9pulIVwvTGV1/NfjySTEr4IlnQuzcURfzf5q7iY5I9SvgiWTBhApSU7HqsqAj22w9OOCGemCR5lPBFsuCYY+Dee6FTJ+jSJaxSHTgwFB/TSlXJFs3SEcmScePgG9+AhQuha9ewKbmSvWSTEr5IK3OHZ56BP/wBOnSAiy6CI46IOypJIiV8kVbkDt/6FsyaFaZhFhXB1Klw/fVw7bVxRydJozF8kVY0f359sgeoqwvTM3/6U6isjDc2SR4lfJFWNGNGfbJvqE0bePrp7McjyaaEL9KKOnWK3jjELMzUEckmJXyRVjRuHBQXNz3uDqNGZT8eSTYlfJFWdNhhcNttYdFV587h0bEjPPVUeC6STZqlI9LKrrgCzj47TM0sKYGvfU3DORIPJXyRLOjRA847L+4oJOk0pCMikhBK+CIiCaGELyKSEEr4IiIJoYQvIpIQSvgiIgmRkYRvZveb2Toze7uZ82Zmd5rZSjNbbGaDM9GuiMRn2zZ49FH45jdh4kRYtCjuiGRPMjUPfypwNzCtmfMjgYGpx/HAvamfIpKHamvhlFPgr38NxeHatIH774e77oJLLok7OmlORnr47v4i8MluLhkDTPPgNaCbmfXORNsikn2PPFKf7AF27ICaGrjqKti4Md7YpHnZGsPvA3zY4HVl6lgTZjbezCrMrKKqqiorwYnI3pk+Pbrsc3ExvPxy9uORlslWwo/audOjLnT3Ke5e7u7lZWVlrRyWiOyLLl2ij7urTlAuy1bCrwT6NXjdF1idpbZFJMMuvxxKS5se79gRhg3LfjzSMtlK+LOAcanZOkOBz919TZbaFpEMGz4cfvSj+rLPXbpAWRnMmRNu4EpuysgsHTN7BBgO9DCzSuAnQDGAu08GZgOnAyuBauDiTLQrIvG57jq49FJ44QXo2hW+8pXozV4kd2Qk4bv7N/dw3oEJmWhLRHLH/vvDOefEHYW0lFbaiogkhBK+iEhCKOGLiCSEEr6ISEIo4YuIJIQ2MReRnFRdDdOmwbx5MGAAXHEFDBwYd1T5TQlfRHLO55/DccfB6tWhZk9xMUyeDI8/DiNHxh1d/tKQjojknF/+Elatqi/Qtm1b6PFfeGGozCn7RglfRHLO44/D1q1Nj9fUwPLl2Y+nUCjhi0jO6dw5+vj27dCpU3ZjKSRK+CKSc666qmmZ5TZtYNAg6N8/npgKgRK+iOSc88+HCy6or8bZqVOYqTNjRtyR5TfN0hGRnGMG994LkybB669D796hzn6RuqhpUcIXkZx1wAHhIZmhhF9oNm6EdeugXz9o1y7uaKQVrVoFv/sdfPIJjBgBp5wSesYizdEvSIWitjbsRtGzJxx1FPToAXfcEXdU0kpmzYIvfQluvBFuvRXOOgvOPFNz1GX3lPALxdVXw8MPw5YtYbXKxo1hD7rHHos7MsmwLVvCTc2amvA9D7BpEzz/PEyfHmtokuOU8AtBTQ1MnRp+NlRdDTffHEtI0npefjl66GbzZnjooezHI/lDCb8QfPZZ8+dWr85aGJIdu9s3VrdtZHeU8AtBz55NV6lA6AYOHZr9eKRVDRsWnfQ7dgy3cUSao4RfCNq0gdtug9LS+mNmIQP8/OfxxSWtom1bmDmzfkFSSQl06ADjxsEZZ8QdneQyTcssFBdcEHr6N90E778fevY33hjWokvBGTYsjNbNnAmffgqnnhpm7Yjsjrl73DE0q7y83CsqKuIOQ0Qkb5jZQncvjzqnIR0RkYTISMI3sxFmttzMVprZpIjzw83sczN7M/X4cSbaFRGRlkt7DN/M2gD3AKcBlcACM5vl7ksbXfqSu49Ktz0REdk3mejhDwFWuvt77l4LPAqMycD7iohIBmUi4fcBPmzwujJ1rLETzGyRmc0xs2anjpjZeDOrMLOKqqqqDIQnUjjc4aWXYNo0eOutuKORfJOJaZlR9fkaT/15AzjA3TeZ2enAU8DAqDdz9ynAFAizdDIQn0hBqKqCk0+GDz4Iid8dhg+HJ5/UCltpmUz08CuBfg1e9wV2Wc/v7hvcfVPq+Wyg2Mx6ZKBtkcS4+OKwgfemTaFuTnU1zJ8Pt9wSd2SSLzKR8BcAA83sQDNrB4wFZjW8wMx6mYVyT2Y2JNXu+gy0LZII1dXwzDNhE++GampgypR4YpL8k/aQjrtvN7OJwFygDXC/uy8xsytS5ycDZwNXmtl2oAYY67m84kskx2zb1vy5LVuyF4fkt4yUVkgN08xudGxyg+d3A3dnoi2RJOraFQ49FBYv3vV427YwenQ8MUn+0UpbkTzxwAOhYFpJSXhdWhrKJ2kMX1pKxdNE8sTgwbBiBdx3H7zzDpx4YqiZ17lz3JFJvlDCF8kjvXrB9dfHHYXkKw3piIgkhBK+iEhCKOGLiCSEEr6ISEIo4YuIJIQSvshecoetW8NPkXyihC+yF556Cv7pn8Kip+7d4Wc/g7q6uKNKz7Zt8JOfQFlZ+O8aOTLM85fCo3n4Ii307LPwrW+FQmYAn38OP/95KGB2883xxpaOCy6AWbPCfwfA3LkwdCgsXQpf/GK8sUlmqYcv0kI//nF9st+puhp+9aswxJOPPvgAZs6sT/YQhqpqauDOO+OLS1qHEr5IC61cGX28rg4+/ji7sWTK0qXQvn3T47W18Prr2Y9HWpcSvtSrrQ2D1FOmaBA3wuGHRx8vLg5FzPLRwIHhr72x4mI48sjsxyOtSwlfgiVLoG9fGDcO/uM/QqWub387/+9IZtDNN0OHDrseKy2FG24ICTIfHXwwfPnL9RU4d2rfHq6+OpaQpBUp4UsYtB09OmyaunFjGJiuqYHp0+GRR+KOLmeceCLMmQPl5SEhHnAA3HEHfO97cUeWnieeCN/zJSVQVBS+6597Dg48MO7IJNMslzeeKi8v94qKirjDKHxvvx2mZWze3PTcv/wLvPRS9mOSrKurC1soakP0/GZmC929POqcevgC773X/Coi7Z+XVZ9+CtdcA/36hfn+t9wSPcbeGoqKlOwLnebhJ9mqVfBv/xZu0Dacl7dThw5w/vlZDyuptmyB448PUyV3JvmbboIXXwxDSSLpUg8/qdzhK18Jm6RGJftOneCII+Dyy7MfW0I99hisXr1rj76mJiT8hQvji0sKhxJ+Uv3v/8JHH8GOHbseLyqCQYNg6lR45ZWm0zek1bzySvRtFHfQrSzJBCX8pFq7FsyaHq+rC3P1vv51aKsRv2w66KDo79e2baF//+zHI4VHCT+pjj8+VM1qrLQURozIfjzCRRc1nc/fpk0o0vbVr8YSkhQYJfyk6ts3jM937Fh/rKQE+vQJ1bQk68rKYP58OPTQMM+/XbswW/all0LiF0lXRhK+mY0ws+VmttLMJkWcNzO7M3V+sZkNzkS7kqbbb4f77gsrigYNgkmTYMGCXb8EJKuOPTbUt3n//XAD9+WXNZwjmZP2IK2ZtQHuAU4DKoEFZjbL3Zc2uGwkMDD1OB64N/VT4mQGY8eGh+SUXr3ijkAKUSZ6+EOAle7+nrvXAo8CYxpdMwaY5sFrQDcz652BtkVEpIUykfD7AB82eF2ZOra31wBgZuPNrMLMKqqqqjIQnoiIQGYSfsTcPhqv02/JNeGg+xR3L3f38rKysrSDE0Jh83PPhRNOCLt45GvxdhFJSyYmWlcC/Rq87gus3odrpDU88ghcemlYsukOf/1rqHe/aBHsv3/c0YlIFmWih78AGGhmB5pZO2AsMKvRNbOAcanZOkOBz919TQbalt3Ztg0mTAjljncWR9u6FT75JFTlEpFESTvhu/t2YCIwF1gGTHf3JWZ2hZldkbpsNvAesBL4DfCddNuVFvjb36IXV23bBn/6U/bjEZFYZWTtvLvPJiT1hscmN3juwIRMtCV7oXv3UOA8iu6PiCSOVtoWst69w6Kqxuv1O3bM/22aCpQ7vPlm2HFqw4a4o5FCo4Rf6P7wh7AnX4cO0LVrKJ/wn/8ZiqNJTvngAzjssLDJ2FlnhcVXd94Zd1RSSFQOsdD16BFKIa9YAWvWwFFHQbducUcljbiHmnUrVuy6b/y118LRR8NJJ8UWmhQQ9fCT4p//Gb78ZSX7HLV4MXz44a7JHsIEq7vuiicmKTxK+CI5YP365itirl1b//yNN+Ccc8JmZJdcAitXZic+KQwa0hHJAccdFz2DtkOHsO0wwLx54fnONXTLlsH06WHE7ogjshmt5Cv18EVyQOfO8ItfhP1ndm5E1qFD2LZg/PiQ4K+8ctc1dDt2wKZN4R68SEuohy+SIyZOhCOPDGP2H30EY8aEPWo6dQqJ/YMPov/cK69kN07JX0r4IjnkpJOiZ+SUlITlFFHr6Pbbr/XjksKgIR2RPNC2LVx8cRjmaai0VGvopOWU8EXyxG23hZu27dtDly6h13/FFWEoSKQlNKQjkifat4eHHw7j+6tWwcCBWlYhe0cJXyTP7L+/tjKQfaMhHRGRhFDCFxFJCCV8EZGEUMKXwB3+8Q8VYRcpYEr4AnPnQv/+cPDBYSesMWPgs8/ijkpEMkwJP+nefhv+/d+hshK2bIHaWnj6aRg9Ou7IJEPefz/sd9OpE/TsCddfH/6aJXk0LTPpbr8dtm7d9VhtLVRUwPLlcMgh8cQlGfHxx2HDs08/DbX2N28OC7jeegtmzow7Osk29fCTbsWKUHaxsXbtmq/WJXnj178OSb7hxio1NaHU8vLl8cUl8VDCT7rhw8MSzsa2bg3bIUpe+8tfwkhdY8XFoZcvyaKEn3RXXRWKsTfcbqm0FC67TMs5C8Dhh4df1hrbvj3co5dkSSvhm9kXzGyemf0t9bN7M9e9b2ZvmdmbZlaRTpuSYT17hn3zLrgAevWCQw8N4/p33BF3ZJIBV17Z9Be4du3CxuhHHx1HRBKndHv4k4A/u/tA4M+p18052d2PdvfyNNuUTOvXDx54ANasgaVLwxZLO7ddkrzWty88/3y4cdumTUj255wDc+bEHZnEId1ZOmOA4annDwLPAz9M8z1FJIMGD4YFC8JYftu24SHJlG4Pf393XwOQ+tmzmesceMbMFprZ+N29oZmNN7MKM6uoqqpKM7w85A6zZ8OZZ8LJJ4dpFo2nTYrsg5ISJfuk2+Nfv5k9C/SKOHXdXrQzzN1Xm1lPYJ6ZvePuL0Zd6O5TgCkA5eXlvhdtFIZJk+Cee8JcOgjTLB58EF54IUytEBHZR3tM+O5+anPnzOwjM+vt7mvMrDewrpn3WJ36uc7MngSGAJEJP9E+/DDcLG3Yo6+uhsWLYcYMOPfc+GITkbyX7pDOLODC1PMLgSZr98yso5l13vkc+CrwdprtFqbmevGbN8Mf/5j9eESkoKSb8P8PcJqZ/Q04LfUaM/uimc1OXbM/8LKZLQL+AvzJ3Z9Os93C9IUvQFHEX0nbtpoTLyJpS+sWjruvB06JOL4aOD31/D1ASzZb4rTTole9FhfDpZdmPx4RKShaaZtLiovhz38Ok6c7dYIuXcLPBx6AL30p7uhEJM9pklauOeKIULSsoiLcsD3+eOjQIe6oRKQAKOHnoqIiGDIk7ihEpMBoSEckh9TVwbp1WmsnrUMJXyRHTJsW6tf17x8mbF1zTahqKZIpGtIRyQGzZ4fKltXV9cd+/euQ8O+8M/33dw+Ltp9/HvbbD84+G7p1S/99Jb+Ye+5WLygvL/eKClVTlsI3ZEgocNZYhw6wfn169+137ICxY8OXSm1tmPlrFrYuHjZs399XcpOZLWyuKrGGdERyQHO7SZqFhJ+ORx8N5ZCrq8NvDJs3w6ZNYe/6qN0tpXAp4YvkgGOPjd6CoF279BdZ//a39bX4GqqpCbN/JTmU8EVywM9+1nTYprQUbr45/SKpuxu1zeERXWkFSvgiOeCYY+DFF0N1je7dw/q7Bx+ECRPSf++LLoKOHZseb9cOjjsu/feX/KFZOiI54thj4ZlnMv++558PTzwBzz0XxvFLSsLavscf33Xveil8SvgiBa5NG5g5E15+GebPD9Myx44NPyX3rFsX7rkMGJD5raWV8EUSwAz+9V/DQ3LT2rXhi/i118KX9Be+AFOnwilN6hHvOyV8EZGYucOpp8Ly5fWrq6urYfRoWLQIDj44M+3opq2ISMxefz2sxWhcSmPbNvjv/85cO0r4IiIxq6yMHq/ftg3efTdz7WhIp1DU1YW18o8/HiZwX3xxmPYhIjmvvDwk98ZKSzM7hq8efiFwh298A849N+yOde+9cNJJ8Mtfxh1Z4XOHGTPC3dBBg+Daa9OvhSCJM2AAnHdeSPA7tWsHPXqEdRSZouJpheDpp0PC37Rp1+MlJfD3v4eau9I6brgBbr+9vnZB+/ahFsLixdC1a7yxSV6pq4Pf/Abuvhs2bAi1jq67LiT9vaHiaYVuxoymyR6gbdvWWckjwccfh9+iGhaq2boVqqpg8uT44pK8VFQEl18Ob70VbuDefvveJ/s9tpHZt5NYdOoUvWTSLHpNvWTGwoWhR99YTU34rUskxyjhF4KLLopOPAAjR+7be65fD7feCpdeClOmRJdbTLpevaK3pCoqCttWieQYJfxCcOSR8F//FcbsO3euf8yatetdoJZaujSs9LjhhlBb95pr4JBDYM2azMeez448MnxObRtNdispge9+N56YRHYjrYRvZt8wsyVmVmdmkTcJUteNMLPlZrbSzCal06Y0Y+JEWLUq7Is3bRp89BEMH75v7/Xtb8Pnn4ehCQi9+48+gh/+MGPhFgSzsLNIeXn9l223bnD//TB4cFZDeeWVMBP3618PG55oL1yJktYsHTM7FKgDfg18392bTKkxszbACuA0oBJYAHzT3Zfu6f01SycGNTUhcUVthdS1K3z2WdZDygurVsGnn8Jhh6VfwH4v3XJLqJtfUxNmiXbsCCecEG4jqBpm8rTaLB13X+buy/dw2RBgpbu/5+61wKPAmHTalVZUVBQeUdq1y24s+aR/fzjqqKwn+7Vr4ac/DXVXdvbdNm+GV18NFTJFGsrGGH4f4MMGrytTxyKZ2XgzqzCziqqqqlYPThpp3z7c6I0al87kChDJiPnzo79jNm8Os3VFGtpjwjezZ83s7YhHS3vpURWdmx1Hcvcp7l7u7uVlZWUtbEIy6r77YODAMN2ztDSMEQwZAjfeGHdk0kiXLtE1WIqKQnldkYb2WEvH3U9Ns41KoF+D132B1Wm+Z7QdO2D27HAjrUcPuPBCOOigVmmqoJWVwdtvwwsvwMqVYajiuOMyvxuDpO3UU6PH6UtK4JJLsh+P5LZsFE9bAAw0swOBfwBjgfMy3sr27WEo4rXXwqrT4uKwCvKhh8LUBdk7RUVw8snhITmrfXuYOzf806+tDd/JtbVw223he1qkoXSnZZ5lZpXACcCfzGxu6vgXzWw2gLtvByYCc4FlwHR3X5Je2BEefjjcqdpZYmDbtjBt4eKLYcuWjDcnkiuOOy4skZgxI8zIXbMmLNEXaSytHr67Pwk8GXF8NXB6g9ezgdnptLVHv/td9GpQszBJOZM1RkVyTHFxGN4R2Z3CWWnboUP0cfcwoCkiknCFk/DHj48uFNahAwwdmv14RERyTOEk/NNPh8suC7350tL6Ze5//KOWG4qIUEhbHJqFAtITJ8Jzz0H37nDGGc0P9YiIJEzhJPydDjpIc+9FRCIUzpCOiIjslhK+iEhCKOGLiCSEEr6ISEIo4YuIJERaO161NjOrAj7Yiz/SA/i4lcIpFPqMdk+fz57pM9qzOD+jA9w9srZ8Tif8vWVmFc1t7SWBPqPd0+ezZ/qM9ixXPyMN6YiIJIQSvohIQhRawp8SdwB5QJ/R7unz2TN9RnuWk59RQY3hi4hI8wqthy8iIs1QwhcRSYiCTfhm9n0zczPrEXcsucTM/q+ZvWNmi83sSTPrFndMucLMRpjZcjNbaWaT4o4n15hZPzObb2bLzGyJmX037phykZm1MbO/mtkf446lsYJM+GbWDzgNWBV3LDloHnC4ux8JrACujTmenGBmbYB7gJHAYcA3zeyweKPKOduB77n7ocBQYII+o0jfBZbFHUSUgkz4wO3ADwDdkW7E3Z9x9+2pl68BfeOMJ4cMAVa6+3vuXgs8CoyJOaac4u5r3P2N1PONhKTWJ96ocouZ9QXOAO6LO5YoBZfwzWw08A93XxR3LHng28CcuIPIEX2ADxu8rkTJrFlmNgA4Bng95lByza8Inc26mOOIlJc7XpnZs0CviFPXAT8CvprdiHLL7j4fd5+ZuuY6wq/ov89mbDnMIo7pN8QIZtYJeAK42t03xB1PrjCzUcA6d19oZsNjDidSXiZ8dz816riZHQEcCCwyMwjDFW+Y2RB3X5vFEGPV3Oezk5ldCIwCTnEtxNipEujX4HVfYHVMseQsMysmJPvfu/uMuOPJMcOA0WZ2OlACdDGz37n7+THH9f8V9MIrM3sfKHd3VfZLMbMRwG3Al929Ku54coWZtSXcxD4F+AewADjP3ZfEGlgOsdCLehD4xN2vjjmcnJbq4X/f3UfFHMouCm4MX/bobqAzMM/M3jSzyXEHlAtSN7InAnMJNyOnK9k3MQy4APhK6t/Om6nerOSJgu7hi4hIPfXwRUQSQglfRCQhlPBFRBJCCV9EJCGU8EVEEkIJX0QkIZTwRUQS4v8BrKyRi4UvdeYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "colors = []\n",
    "idxs = []\n",
    "for word, idx in zidian.items():\n",
    "    idxs.append(idx)\n",
    "    if word in '1234567890':\n",
    "        colors.append('red')\n",
    "        continue\n",
    "    colors.append('blue')\n",
    "\n",
    "embed = model.embed(torch.LongTensor(idxs))\n",
    "embed = embed.detach().numpy()\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.scatter(embed[:, 0], embed[:, 1], c=colors)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
